package com.boot.base.common.jdbc;


import com.boot.base.common.enity.TPageInfo;
import org.springframework.dao.DataAccessException;
import org.springframework.jdbc.core.BeanPropertyRowMapper;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.ParameterDisposer;
import org.springframework.jdbc.core.PreparedStatementCreator;
import org.springframework.jdbc.core.PreparedStatementSetter;
import org.springframework.jdbc.core.RowMapper;
import org.springframework.jdbc.support.JdbcUtils;
import org.springframework.jdbc.support.KeyHolder;
import org.springframework.lang.Nullable;

import javax.sql.DataSource;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.List;

/**
 * {@link JdbcTemplate} extension adding typed single-value / list helpers, paged queries and an
 * {@code update} variant that also reports generated keys.
 *
 * <p>Paged queries wrap the caller's SQL as a derived table and append {@code LIMIT offset,count}.
 * That is MySQL/MariaDB-style syntax; other databases may reject it.
 */
public class TJdbcTemplate extends JdbcTemplate {

    public TJdbcTemplate(DataSource dataSource) {
        super(dataSource);
    }

    @Override
    public <T> List<T> query(String sql, RowMapper<T> rowMapper) throws DataAccessException {
        return super.query(sql, rowMapper);
    }

    /**
     * Runs {@code sql} with the positional {@code args} and maps every row with {@code rowMapper}.
     *
     * @return the mapped rows, never {@code null}
     * @throws DataAccessException if the query fails; failures are propagated, not swallowed
     */
    @Override
    public <T> List<T> query(String sql, @Nullable Object[] args, RowMapper<T> rowMapper) throws DataAccessException {
        return super.query(sql, args, rowMapper);
    }

    /**
     * Maps the single row returned by {@code sql} onto a new {@code clz} instance using
     * {@link BeanPropertyRowMapper}.
     *
     * @throws DataAccessException if the query fails or does not return exactly one row
     */
    public <T> T getObject(String sql, Class<T> clz, Object... args) throws DataAccessException {
        return super.queryForObject(sql, new BeanPropertyRowMapper<>(clz), args);
    }

    /** Returns the single-column, single-row result of {@code sql} as an {@code Integer}. */
    public Integer getInt(String sql, Object... args) throws DataAccessException {
        return super.queryForObject(sql, Integer.class, args);
    }

    /** Returns the single-column, single-row result of {@code sql} as a {@code String}. */
    public String getStr(String sql, Object... args) throws DataAccessException {
        return super.queryForObject(sql, String.class, args);
    }

    /** Returns the single-column, single-row result of {@code sql} as a {@code Long}. */
    public Long getLong(String sql, Object... args) throws DataAccessException {
        return super.queryForObject(sql, Long.class, args);
    }

    /**
     * Runs {@code sql} and returns every row converted to {@code clz}.
     *
     * <p>Scalar types (see {@link #isSingleColumnType(Class)}) are read from the single result
     * column. Every other type is treated as a bean and populated by column name via
     * {@link BeanPropertyRowMapper}.
     */
    public <T> List<T> getList(String sql, Class<T> clz, Object... args) {
        if (isSingleColumnType(clz)) {
            return super.queryForList(sql, clz, args);
        }
        return super.query(sql, new BeanPropertyRowMapper<>(clz), args);
    }

    /**
     * Decides whether {@code clz} should be read as a single column instead of mapped as a bean.
     *
     * <p>Covers primitives, {@code java.lang.*} (boxed numbers, {@code String}, ...), other
     * {@link Number}s such as {@code BigDecimal}, JDBC/legacy date types, {@code java.time} values
     * and {@code byte[]}. None of these can be built by {@link BeanPropertyRowMapper}.
     */
    private static boolean isSingleColumnType(Class<?> clz) {
        return clz.isPrimitive()
                || clz.getName().startsWith("java.lang.")
                || Number.class.isAssignableFrom(clz)
                || java.util.Date.class.isAssignableFrom(clz)
                || java.time.temporal.Temporal.class.isAssignableFrom(clz)
                || clz == byte[].class;
    }

    /**
     * Runs a paged query.
     *
     * <p>First counts all rows of {@code sql}, then fetches the requested page. The same
     * {@code objects} are bound to both statements.
     *
     * @param sql       the base query; it is wrapped as a derived table
     * @param rowMapper maps each row of the page
     * @param pageNum   1-based page number
     * @param pageSize  rows per page; {@code 0} means all rows. With {@code 0}, the page size
     *                  becomes the total count, so only page 1 returns data.
     * @param objects   positional bind values for {@code sql}
     * @param <T>       row type
     * @return the page data together with the total row count, page number and page size
     * @throws IllegalArgumentException if {@code pageNum < 1} or {@code pageSize < 0}
     */
    public <T> TPageInfo<T> getQueryPage(String sql, RowMapper<T> rowMapper,
                                         int pageNum, int pageSize, Object... objects) {
        // Reject inputs that would otherwise produce a negative LIMIT and an opaque SQL error.
        if (pageNum < 1) {
            throw new IllegalArgumentException("pageNum must be >= 1, got " + pageNum);
        }
        if (pageSize < 0) {
            throw new IllegalArgumentException("pageSize must be >= 0 (0 = all rows), got " + pageSize);
        }

        TPageInfo<T> page = new TPageInfo<>();

        Integer total = getInt(getCountSql(sql), objects);
        page.setTotal(total);

        if (pageSize == 0) {
            // "All rows": use the total count as the page size.
            pageSize = page.getTotal();
        }
        List<T> rows = query(getPageSelSql(sql, pageNum, pageSize), rowMapper, objects);
        page.setPageNum(pageNum);
        page.setPageSize(pageSize);
        page.setData(rows);
        return page;
    }

    /**
     * Runs a paged query, mapping each row onto {@code clz} with {@link BeanPropertyRowMapper}.
     *
     * @see #getQueryPage(String, RowMapper, int, int, Object...)
     */
    public <T> TPageInfo<T> getQueryPage(String sql, Class<T> clz, int pageNum, int pageSize, Object... objects) {
        return getQueryPage(sql, new BeanPropertyRowMapper<T>(clz), pageNum, pageSize, objects);
    }

    /**
     * Builds the page query.
     *
     * <p>Wraps {@code sql} as a derived table and appends {@code LIMIT offset,count}. The offset is
     * computed in {@code long} so large page numbers cannot overflow into a negative value.
     */
    private String getPageSelSql(String sql, int pageNum, int pageSize) {
        long offset = (long) (pageNum - 1) * pageSize;
        return "select * from (" + sql + ") _aa limit " + offset + "," + pageSize;
    }

    /** Builds a query that counts the rows produced by {@code sql}, wrapped as a derived table. */
    private String getCountSql(String sql) {
        return String.format("select count(*) from ( %s ) _aa ", sql);
    }

    /**
     * Executes a modifying statement and captures any generated keys into {@code keyHolder}.
     *
     * <p>Arguments are bound exactly as in {@link JdbcTemplate#update(String, Object...)}, which
     * covers {@code null} values, {@code SqlParameterValue} and a {@code null} args array.
     *
     * @param sql       the INSERT/UPDATE statement with {@code ?} placeholders
     * @param keyHolder receives the generated keys (e.g. the new row id)
     * @param args      positional bind values
     * @return the number of affected rows
     * @throws DataAccessException if the statement fails
     */
    public int update(String sql, KeyHolder keyHolder, Object... args) throws DataAccessException {
        PreparedStatementSetter setter = newArgPreparedStatementSetter(args);
        PreparedStatementCreator creator = con -> {
            PreparedStatement ps = con.prepareStatement(sql, Statement.RETURN_GENERATED_KEYS);
            try {
                setter.setValues(ps);
            } catch (SQLException | RuntimeException ex) {
                // JdbcTemplate only closes statements it received from the creator, so a
                // binding failure here would otherwise leak this statement.
                JdbcUtils.closeStatement(ps);
                throw ex;
            }
            return ps;
        };
        try {
            return super.update(creator, keyHolder);
        } finally {
            // Same cleanup the base class performs for its own argument setters.
            if (setter instanceof ParameterDisposer) {
                ((ParameterDisposer) setter).cleanupParameters();
            }
        }
    }

}

